% Estimate, and do inference on, individual-specific treatment effects in JTPA data

%% Organize data
% Run the original data script and keep only its covariate matrix, stored
% as XOLD; every other variable it created is discarded.
createJTPAdata ;
XOLD = X ;
clearvars -except XOLD ;

% Run the second data script.  The workspace was just cleared, so y, d and
% the new X used below must come from this script.  The original basis is
% then appended as extra columns.
createJTPAdatav4 ;
X = horzcat(X , XOLD) ;

% Sample size (n) and number of covariates (px)
[n,px] = size(X) ;

% Outcome and covariates for the control observations (d == 0)
yd0 = y(d == 0,:);
Xd0 = X(d == 0,:);

% Outcome and covariates for the treated observations (d == 1)
yd1 = y(d == 1,:);
Xd1 = X(d == 1,:);

% Big design matrix with all treatment interactions:
% [intercept , covariates , treatment indicator , treatment x covariates]
Z = [ones(n,1) , X , d , repmat(d,1,px).*X];

%% OLS with everything
%{
bolsJTPA = Z\y;
[solsJTPA,VolsJTPA] = hetero_se(Z,y-Z*bolsJTPA,pinv(Z'*Z));
%}
%% Lasso with plug-in tuning
% Each treatment arm gets its own fit of the demeaned outcome on the
% demeaned covariates via feasiblePostLasso.  The starting coefficient
% vector passed as 'beta0' is a least-squares fit on the covariates whose
% marginal correlation with the outcome is among the largest in that arm.
%
% Centering is done with bsxfun against the subset itself, so the number of
% rows always matches the subset (no reliance on sum(d) equalling the
% number of rows with d == 1), and mean(...,1) keeps column means even if a
% subset has a single row.

% Number of covariates used to build the starting value: 5, capped at px so
% that czys(nInit) is a valid index when fewer than 5 covariates exist.
nInit = min(5,px);

% Control observations
% Initialize with the nInit most marginally correlated variables (ties at
% the cutoff are all kept)
czy0 = corr(yd0,Xd0);
czys = sort(-abs(czy0));
suppy0 = find(abs(czy0) >= -czys(nInit));
tmpy0 = bsxfun(@minus,Xd0(:,suppy0),mean(Xd0(:,suppy0),1))\(yd0-mean(yd0));
by0 = zeros(size(Xd0,2),1);
by0(suppy0) = tmpy0;
blas0 = feasiblePostLasso(yd0-mean(yd0),bsxfun(@minus,Xd0,mean(Xd0,1)),...
    'MaxIter',1,'beta0',by0);
% Prepend the intercept implied by the demeaned fit
blas0 = [(mean(yd0)-mean(Xd0,1)*blas0);blas0];
use0 = blas0 ~= 0;

% Treatment observations (same steps as for the control arm)
czy1 = corr(yd1,Xd1);
czys = sort(-abs(czy1));
suppy1 = find(abs(czy1) >= -czys(nInit));
tmpy1 = bsxfun(@minus,Xd1(:,suppy1),mean(Xd1(:,suppy1),1))\(yd1-mean(yd1));
by1 = zeros(size(Xd1,2),1);
by1(suppy1) = tmpy1;
blas1 = feasiblePostLasso(yd1-mean(yd1),bsxfun(@minus,Xd1,mean(Xd1,1)),...
    'MaxIter',1,'beta0',by1);
blas1 = [(mean(yd1)-mean(Xd1,1)*blas1);blas1];
use1 = blas1 ~= 0;

% Post-lasso
% Union of the positions (in [intercept , X]) that are nonzero in either
% arm.  Adding px+1 maps each position to the matching column of the
% [d , d.*X] part of Z, so a selected term enters together with its
% treatment interaction.
% NOTE(review): the intercept (and hence the d main effect) is included
% only because its estimate is nonzero in at least one arm -- confirm that
% this is the intended rule.
usepl = union(find(use0),find(use1));
usepl = [usepl ; (usepl+px+1)];
Zpl = Z(:,usepl);
bpl = Zpl\y;
[sepl,Vpl] = hetero_se(Zpl,y-Zpl*bpl,pinv(Zpl'*Zpl));

% Embed the post-lasso results in full-length (2*px+2) sparse containers;
% entries for unselected columns of Z stay zero.
bplJTPA = sparse(2*px+2,1);
bplJTPA(usepl) = bpl;
splJTPA = sparse(2*px+2,1);
splJTPA(usepl) = sepl;
VplJTPA = sparse(2*px+2,2*px+2);
VplJTPA(usepl,usepl) = Vpl;

%% Wald test that all variables interacted with treatment have 0 coefficients
% Columns px+3 : 2*px+2 of Z hold the covariate-by-treatment interactions.
testset = ((px+3):(2*px+2))';
% Hypothesized coefficient values (all zero), indexed like the columns of Z
r = zeros(size(bplJTPA));

% % OLS
% WOLS = (bolsJTPA(testset) - r(testset))'*(VolsJTPA(testset,testset)\(bolsJTPA(testset) - r(testset)));
% pvalOLS = 1-chi2cdf(WOLS,numel(testset));

% Post-lasso
% Only interactions that were selected have an estimated coefficient and a
% covariance block, so the test is restricted to those columns.
testpl = intersect(testset,usepl);
WPL = (bplJTPA(testpl) - r(testpl))'*(VplJTPA(testpl,testpl)\(bplJTPA(testpl) - r(testpl)));
% Request the upper tail directly: 1-chi2cdf(...) rounds to exactly 0 once
% the cdf is within eps of 1, losing all precision for large statistics.
% WPL is built from sparse operands, so it is converted to a full scalar
% before being handed to chi2cdf.
pvalPL = chi2cdf(full(WPL),numel(testpl),'upper');

% Forward selection
% NOTE(review): the last argument (10) is presumably the number of
% forward-selection steps -- confirm against fsel_wald_big.
[FSpval,FSW,FSdf] = fsel_wald_big(Z,y,testset,r,usepl,10);


%% Conditional average treatment effects

% Row i of Xte picks out the treatment part of the coefficient vector for
% observation i: zeros for the [intercept , X] columns of Z, then 1 for the
% d column and X(i,:) for the d.*X columns.
Xte = [sparse(n,(px+1)) , ones(n,1) , X];
%{
% OLS estimates
teolsJTPA = Xte*bolsJTPA;
steolsJTPA = sqrt(diag(Xte*VolsJTPA*Xte'));
%}
% Post-lasso estimates
teplJTPA = Xte*bplJTPA;
% Pointwise standard errors sqrt(x_i' V x_i).  Computed row by row as
% sum((Xte*V).*Xte,2), which equals diag(Xte*V*Xte') without forming the
% n-by-n product just to read its diagonal.
steplJTPA = sqrt(sum((Xte*VplJTPA).*Xte,2));

% TU estimates

% There are not n unique values of X.  Just compute TU for one of each
% type.
nT = numel(VUnique);

% For each value in VUnique, copy the Xte row of the first observation whose
% V matches it.
XteUnique = zeros(nT,size(Xte,2));
for jj = 1:nT
    indjj = find(V == VUnique(jj),1);   % first match only
    XteUnique(jj,:) = Xte(indjj,:);
end

% Save the whole workspace; the commented-out code that follows describes
% the remaining computation as being run separately in parallel.
save JTPA_CATE_UNION_INIT;

%% Takes too long, so this part is getting brute force parallelized
% % Precompute crossproduct matrices
% MZZ = Z'*Z;  
% MZY = Z'*y;  
% 
% TUteJTPAlb = zeros(nT,10);
% TUteJTPAub = zeros(nT,10);
% parfor ii = 1:nT
%     fprintf('Iteration: %d \n',ii);
%     [TUteJTPAlb(ii,:),TUteJTPAub(ii,:)] = fsel_linear_big(Z,y,XteUnique(ii,:)',usepl,MZZ,MZY,10);    
% end
%    
% %% save output
% 
% save JTPA_CATE_HAD ;
%
% %% Make plot of estimated effects
% % Plots sorted by treatment effect estimate
% %{
% olsres = [teolsJTPA steolsJTPA];
% olsres = sortrows(olsres,1);
% %}
% plres = [teplJTPA steplJTPA];
% plres = sortrows(plres,1);
% 
% fsreslb = [teplJTPA TUteJTPAlb];
% fsreslb = sortrows(fsreslb,1);
% fsresub = [teplJTPA TUteJTPAub];
% fsresub = sortrows(fsresub,1);
% 
% %{
% figure; plot(1:n,olsres(:,1),'b',1:n,olsres(:,1)-1.96*olsres(:,2),'r',1:n,olsres(:,1)+1.96*olsres(:,2),'r');
% title('OLS Pointwise 95% CI for CATE(x)');
% savefig('JTPACATE_OLS');
% %}
% figure; plot(1:n,plres(:,1),'b',1:n,plres(:,1)-1.96*plres(:,2),'r',1:n,plres(:,1)+1.96*plres(:,2),'r');
% title('Post-Lasso Pointwise 95% CI for CATE(x)');
% savefig('JTPACATE_PL');
% 
% for ii = 1:10
%     figure; plot(1:n,plres(:,1),'b',1:n,fsreslb(:,1+ii),'r',1:n,fsresub(:,1+ii),'r');
%     title(sprintf('FS(%d) Pointwise 95% CI for CATE(x)',ii));
%     filename = sprintf('JTPACATE_FS%d',ii);
%     savefig(filename);
% end
% 
% 
% %{
% %% Treatment effects for a few individuals
% % Just look at first individual within each level of treatment effect
% unique_te = unique(teplJTPA);
% nte = numel(unique_te);
% 
% indlbs = zeros(nte,2+size(TUteJTPAlb,2));
% indubs = zeros(nte,2+size(TUteJTPAub,2));
% indpes = zeros(nte,2);
% for jj = 1:nte
%     indjj = find(teplJTPA == unique_te(jj));
%     indlbs(jj,:) = [teolsJTPA(indjj(1)) - 1.96*steolsJTPA(indjj(1)) , ...
%         teplJTPA(indjj(1)) - 1.96*steplJTPA(indjj(1)) , TUteJTPAlb(indjj(1),:)];
%     indubs(jj,:) = [teolsJTPA(indjj(1)) + 1.96*steolsJTPA(indjj(1)) , ...
%         teplJTPA(indjj(1)) + 1.96*steplJTPA(indjj(1)) , TUteJTPAub(indjj(1),:)];
%     indpes(jj,:) = [teolsJTPA(indjj(1)) , teplJTPA(indjj(1))];
% end
% 
% % Make plots for these individuals
% for jj = 1:nte
%     figure; plot((0:9),indlbs(jj,1)*ones(1,10),'r',(0:9),indubs(jj,1)*ones(1,10),'r');
%     hold on; plot((0:9),indpes(jj,1)*ones(1,10),'r--');
%     hold on; plot((0:9),indlbs(jj,2:end),'b',(0:9),indubs(jj,2:end),'b');
%     hold on; plot((0:9),indpes(jj,2)*ones(1,10),'b--');
%     hold on; plot((0:9),zeros(1,10),'k');
%     title(sprintf('95% Interval Estimate for CATE for Individual %d',jj));
%     xlabel('Number of Additional Variables');
%     ylabel('CATE');
%     filename = sprintf('JTPAind%d',jj);
%     saveas(gcf, filename, 'jpeg');
% end
% 
% %% Look at Wald test of all variables interacted with treatment have 0 coefficients
% testset = ((px+3):(2*px+2))';
% r = zeros(size(bolsJTPA));
% 
% % OLS
% WOLS = (bolsJTPA(testset) - r(testset))'*(VolsJTPA(testset,testset)\(bolsJTPA(testset) - r(testset)));
% pvalOLS = 1-chi2cdf(WOLS,numel(testset));
% 
% % Post-lasso
% testpl = intersect(testset,usepl);
% WPL = (bplJTPA(testpl) - r(testpl))'*(VplJTPA(testpl,testpl)\(bplJTPA(testpl) - r(testpl)));
% pvalPL = 1-chi2cdf(WPL,numel(testpl));
% 
% % Forward selection
% [FSpval,FSW,FSdf] = fsel_wald(Z,y,testset,r,usepl,ceil(log(n))+1);
% 
% %% Save results again
% save JTPA_CATE ;
% %}
